#! /usr/bin/env python
#coding=utf-8

import string
import sys
import os
import sqlite3

from tablebuilder import SymbolsTableBuilder
from tablebuilder import UndefinedSymbolsTableBuilder
from tablebuilder import CallsTableBuilder
from tablebuilder import DepedenceiesTableBuilder
from tablebuilder import IndirectsTableBuilder
from tablebuilder import ModuleTableBuilder
from tablebuilder import SymbolNamesTableBuilder

import elf_walker
import elf_file
import elf_file_mgr
from mem import MemDbBuilder
from startup_config import config_parser
from parameter import parameter_parser

class ElfFileWithDbInfo(elf_file_mgr.ElfFileWithDepsInfo):
    """ELF file record that can write itself, its symbols and its summary counters to archinfo.db.

    Instances are used like dicts (keys such as "id", "name", "path" and "deps"
    are read below). The counters computed by updateCountsInfo are stored as
    extra keys on the instance, so that the modules table builder persists them
    together with the file.
    """

    def __init__(self, file, prefix):
        super(ElfFileWithDbInfo, self).__init__(file, prefix)

    def addToDb(self, builder, hasMemInfo, execute=True):
        """Refresh the summary counters, then add this file through builder.

        Args:
            builder: modules table builder; its cursor also runs the count queries.
            hasMemInfo: True when the memory-information tables exist.
            execute: passed through to builder.addObjectToDb.

        Returns:
            The result of builder.addObjectToDb (the SQL text when execute is False).
        """
        self.updateCountsInfo(builder.getCursor(), hasMemInfo)
        return builder.addObjectToDb(self, execute)

    def _addRowsToDb(self, builder, rows):
        """Insert rows, each tagged with this file's id as parent_id, in one transaction script.

        Returns:
            The number of rows inserted.
        """
        statements = []
        for row in rows:
            row["parent_id"] = self["id"]
            statements.append(builder.addObjectToDb(row, False))
        # A single join instead of `script = script + sqlcmd + ";\n"`, which
        # copied the whole script again for every symbol.
        script = "BEGIN TRANSACTION;\n" + "".join(s + ";\n" for s in statements) + "COMMIT;\n"
        builder.getCursor().executescript(script)
        return len(statements)

    def addSymbolsToDb(self, builder):
        """Insert every entry of provided_symbols(); return how many were added."""
        return self._addRowsToDb(builder, self.provided_symbols())

    def addUndefinesToDb(self, builder):
        """Insert every entry of undefined_symbols(); return how many were added."""
        return self._addRowsToDb(builder, self.undefined_symbols())

    def addCallsToDb(self, cursor):
        """Record the matched calls of every direct dependency; return the total count."""
        cnt = 0
        for dep in self["deps"]:
            cnt = cnt + dep.addMatchedCallsToDb(cursor)

        print("    %s add %d matched calls in %d deps" % (self["name"], cnt, len(self["deps"])))
        return cnt

    @staticmethod
    def _queryCount(cursor, sqlcmd):
        """Return the first column of the first row of sqlcmd as an int.

        Returns 0 when the query fails with a database error (for instance,
        when a table or view it reads does not exist) or when it yields no
        row or a NULL value.
        """
        try:
            cursor.execute(sqlcmd)
            row = cursor.fetchone()
        except sqlite3.Error:
            return 0
        if row is None or row[0] is None:
            return 0
        return int(row[0])

    # All provided symbols
    def __get_provided_symbols_cnt(self, cursor):
        if "provided" in self:
            return self["provided"]
        return self._queryCount(cursor, "select count(id) from symbols where parent_id=%d" % (self["id"]))

    # Provided symbols being used by other modules
    def __get_used_symbols_cnt(self, cursor):
        return self._queryCount(cursor, "select count(distinct(symbol_id)) from calls where callee_id=%d" % (self["id"]))

    # Needed (undefined) symbols count
    def __get_needed_symbols_cnt(self, cursor):
        return self._queryCount(cursor, "select count(name) from undefines where parent_id=%d" % (self["id"]))

    # Undefined symbols being provided by depended libraries
    def __get_matched_symbols_cnt(self, cursor):
        return self._queryCount(cursor, "select count(distinct(symbol_id)) from calls where caller_id=%d" % (self["id"]))

    # Undefined symbols provided by more than one depended library
    def __get_duplicated_symbols_cnt(self, cursor):
        # Count the names that occur more than once. The former
        # "... GROUP by name HAVING count(*) > 1" query returned one row per
        # such name, and only the first row (one name's count) was read.
        return self._queryCount(
            cursor,
            "select count(*) from (select name from call_details where caller_id=%d"
            " GROUP by name HAVING count(*) > 1)" % (self["id"]))

    # Undefined symbols NOT provided by any depended library
    def __get_unmatched_symbols_cnt(self, cursor):
        return self._queryCount(
            cursor,
            "select count(name) from undefines where parent_id=%d and name not in "
            "(select distinct(name) from call_details where caller_id=%d)" % (self["id"], self["id"]))

    def updateCountsInfo(self, cursor, hasMemInfo):
        """Store the symbol, memory and SDK counters on this record."""
        self["needed"] = self.__get_needed_symbols_cnt(cursor)
        self["matched"] = self.__get_matched_symbols_cnt(cursor)
        self["duplicated"] = self.__get_duplicated_symbols_cnt(cursor)
        self["unmatched"] = self.__get_unmatched_symbols_cnt(cursor)

        self["provided"] = self.__get_provided_symbols_cnt(cursor)
        self["used"] = self.__get_used_symbols_cnt(cursor)

        self._updateMemInfo(cursor, hasMemInfo)
        self._updateChipsetSdkInfo(cursor)

    def _updateMemInfo(self, cursor, hasMemInfo):
        """Copy this file's memory columns from the objects table (name == path).

        Every memory column, and object_id, defaults to 0. Columns are only
        overwritten with truthy values from the first matching objects row.
        """
        self["object_id"] = 0
        cols = ModuleTableBuilder.MEMINFO_COLS
        for c in cols:
            self[c] = 0
        if not hasMemInfo:
            return
        # Bind the path as a parameter: a quote in the path used to break the SQL.
        sqlcmd = "select id, %s from objects where name=?" % ", ".join(cols)
        cursor.execute(sqlcmd, (self["path"],))
        row = cursor.fetchone()
        if row is None:
            return
        self["object_id"] = row[0]
        for col, val in zip(cols, row[1:]):
            if val:
                self[col] = val

    def __get_chipsetsdk_dependedBy_cnt(self, cursor):
        return self._queryCount(cursor, "select count(id) from dependencies where callee_id=%d and chipsetsdk=1" % (self["id"]))

    def __get_chipsetsdk_symbols_cnt(self, cursor):
        return self._queryCount(cursor, "select count(distinct(symbol_id)) from call_details where callee_id=%d and chipsetsdk=1" % (self["id"]))

    def __get_platformsdk_dependedBy_cnt(self, cursor):
        return self._queryCount(cursor, "select count(id) from dependencies where callee_id=%d and platformsdk=1" % (self["id"]))

    def __get_platformsdk_symbols_cnt(self, cursor):
        return self._queryCount(cursor, "select count(distinct(symbol_id)) from call_details where callee_id=%d and platformsdk=1" % (self["id"]))

    def __get_external_symbols_cnt(self, cursor):
        return self._queryCount(cursor, "select count(distinct(symbol_id)) from call_details where callee_id=%d and external=1" % (self["id"]))

    def _updateChipsetSdkInfo(self, cursor):
        """Store chipset/platform SDK and external-usage counters on this record."""
        self["chipsetsdk_DependedBy"] = self.__get_chipsetsdk_dependedBy_cnt(cursor)
        self["chipsetsdk_symbols"] = self.__get_chipsetsdk_symbols_cnt(cursor)
        self["platformsdk_DependedBy"] = self.__get_platformsdk_dependedBy_cnt(cursor)
        self["platformsdk_symbols"] = self.__get_platformsdk_symbols_cnt(cursor)
        self["external_symbols"] = self.__get_external_symbols_cnt(cursor)

class DependencyDB(elf_file_mgr.Dependency):
    """A caller -> callee dependency that can be written to the dependencies table.

    It can also record, in the calls table, the concrete symbols through
    which the caller uses the callee.
    """

    def __init__(self, idx, caller, callee):
        super(DependencyDB, self).__init__(idx, caller, callee)

    def addToDb(self, builder, execute=True):
        """Fill in the "calls" count if needed, then add this record via builder.

        Returns whatever builder.addObjectToDb(self, execute) returns.
        """
        cursor = builder.getCursor()
        self._updateCallsFromDb(cursor)
        return builder.addObjectToDb(self, execute)

    def _updateCallsFromDb(self, cursor):
        """Load the number of calls rows for this dependency.

        Skipped when a positive count is already set (e.g. by addMatchedCallsToDb).
        """
        if not self["calls"] > 0:
            cursor.execute("select count(symbol_id) from calls where dependence_id=%d" % (self["id"]))
            first = cursor.fetchone()
            if first is not None:
                self["calls"] = first[0]

    def addMatchedCallsToDb(self, cursor):
        """Insert a calls row for each callee symbol whose name the caller has undefined.

        Returns the number of rows inserted; the same value is stored in self["calls"].
        """
        callerId = self["caller_id"]
        calleeId = self["callee_id"]
        cursor.execute("select id from symbols where parent_id=%d and name in (select name from undefines where parent_id=%d)" % (calleeId, callerId))
        inserts = [
            'insert into calls(caller_id, callee_id, symbol_id, dependence_id) values (%d,%d,%d,%d)' %
            (callerId, calleeId, row[0], self["id"])
            for row in cursor
        ]
        body = "".join(stmt + ";\n" for stmt in inserts)
        cursor.executescript("BEGIN TRANSACTION;\n" + body + "COMMIT;\n")

        total = len(inserts)
        self["calls"] = total
        return total

class IndirectDep(dict):
    """Row for the indirects table: an indirect (transitive) dependency caller -> callee.

    "external" is True when caller and callee have different componentName values.
    "calls" always starts at 0.
    """

    def __init__(self, idx, caller, callee):
        super(IndirectDep, self).__init__()
        fields = (
            ("id", idx),
            ("caller_id", caller["id"]),
            ("callee_id", callee["id"]),
            ("calls", 0),
            ("external", caller["componentName"] != callee["componentName"]),
        )
        for key, value in fields:
            self[key] = value

class ArchInfoCollector(elf_file_mgr.ElfFileMgr):
    """Build <output>/archinfo.db from the ELF files found under the input directory.

    All work is done by the constructor. It scans the files, rebuilds the
    tables and views selected by the command-line flags, then commits and
    closes the database.
    """

    def __init__(self, args, elfFileClass=None, dependenceClass=None):
        """Run the collection selected by args.

        Args:
            args: namespace from createArgParser(); args.output is set to
                args.input in place when it is empty.
            elfFileClass: class for ELF file records; defaults to ElfFileWithDbInfo.
            dependenceClass: class for dependency records; defaults to DependencyDB.
        """
        # Both class arguments used to be ignored; honour them, keeping the
        # database-aware classes as defaults.
        super(ArchInfoCollector, self).__init__(
            args.input,
            elfFileClass or ElfFileWithDbInfo,
            dependenceClass or DependencyDB)

        if not args.output:
            args.output = args.input

        if not os.path.exists(args.output):
            os.makedirs(args.output)

        self._conn = sqlite3.connect(os.path.join(args.output, "archinfo.db"))
        self._cursor = self._conn.cursor()

        try:
            self._run_selected_steps(args)
        except BaseException:
            # Do not leak the database handle when a step fails; anything not
            # committed yet is discarded instead of being committed half-done.
            self._cursor.close()
            self._conn.close()
            raise

        self.__close_db_file()

        print("archinfo database created successfully.")

    def _run_selected_steps(self, args):
        """Run the pipeline chosen by the (mutually exclusive) command-line flags."""
        if args.mem:            # Update memory information only
            self._update_meminfo_only()
            return

        self.scan_all_files()

        if args.none:           # Do nothing, scan files only
            pass
        elif args.last:         # Update last step module summary
            self._add_all_modules()
        elif args.depends:      # Update dependencies information only
            self._add_all_dependencies()
            self._add_all_indirects()
            self._add_symbol_names()
            self._add_all_modules()
        else:                   # Update all
            self._add_all()

    def _execute_statements(self, statements):
        """Run SQL statements (given without a trailing ';') as one BEGIN/COMMIT script."""
        # One join instead of `script = script + sqlcmd + ";\n"`, which copied
        # the whole script again for every statement.
        script = "BEGIN TRANSACTION;\n" + "".join(s + ";\n" for s in statements) + "COMMIT;\n"
        self._cursor.executescript(script)

    # Add all symbols
    def _add_all_symbols(self):
        """Fill the symbols and undefines tables; (re)create the symbols_count view.

        Returns:
            The number of provided symbols added.
        """
        builder = SymbolsTableBuilder(self._cursor)
        builder.createTable()

        undefines_builder = UndefinedSymbolsTableBuilder(self._cursor)
        undefines_builder.createTable()

        # Add all symbols information into database
        elfFiles = self.get_all()
        total = len(elfFiles)
        cnt = 0
        undefinesCnt = 0
        for idx, elf in enumerate(elfFiles, 1):
            print("[1/5][%d|%d] add symbols for file %s" % (idx, total, elf["path"]))
            cnt += elf.addSymbolsToDb(builder)
            undefinesCnt += elf.addUndefinesToDb(undefines_builder)

        print("Add %d symbols and %d undefines for %d files finished." % (cnt, undefinesCnt, total))

        # create symbols_count view
        self._cursor.execute("drop view if exists symbols_count")
        sqlcmd = "create view IF NOT EXISTS [symbols_count] as SELECT symbols.parent_id, symbols.name, symbols.weak, symbols.version, symbols.library, symbols.demangle, symbols.id, calls_count.count from symbols LEFT OUTER join calls_count on symbols.id=calls_count.symbol_id"
        self._cursor.execute(sqlcmd)

        return cnt

    def _add_all_calls(self):
        """Fill the calls table from every file's dependencies and (re)create the call views."""
        builder = CallsTableBuilder(self._cursor)
        builder.createTable()

        # It will update calls for dependencies and matched calls for elf file
        elfFiles = self.get_all()
        total = len(elfFiles)
        cnt = 0
        for idx, elf in enumerate(elfFiles, 1):
            print("[2/5][%d|%d] add function calls for file %s" % (idx, total, elf["path"]))
            cnt += elf.addCallsToDb(self._cursor)

        print("Add %d calls %d files finished." % (cnt, total))

        self._cursor.execute("drop view if exists call_details")
        symbolCols = ["syms.%s as %s" % (c, c) for c in SymbolsTableBuilder(self._cursor).getColumns()]
        callsCols = ["calls.%s as %s" % (c, c) for c in builder.getColumns()]
        depsCols = ["deps.%s as %s" % (c, c) for c in ("external", "platformsdk", "chipsetsdk")]
        self._cursor.execute("create view [call_details] as select %s, callers.name as caller, %s, %s from calls left outer join symbols syms on syms.id = calls.symbol_id left outer join modules callers on callers.id = calls.caller_id left outer join dependencies deps on deps.id=calls.dependence_id" % (", ".join(symbolCols), ", ".join(callsCols), ", ".join(depsCols)))

        self._cursor.execute("drop view if exists calls_count")
        sqlcmd = "create view IF NOT EXISTS [calls_count] as select symbol_id, count(caller_id) as count from calls group by symbol_id"
        self._cursor.execute(sqlcmd)

        self._cursor.execute("drop view if exists calls_count_details")
        sqlcmd = "create view IF NOT EXISTS [calls_count_details] as select calls_count.symbol_id, calls_count.count, symbols.name from calls_count left outer join symbols on calls_count.symbol_id=symbols.id"
        self._cursor.execute(sqlcmd)

    def _add_symbol_names(self):
        """Fill the symbol names table: one row per distinct symbol name plus its total call count."""
        builder = SymbolNamesTableBuilder(self._cursor)
        builder.createTable()

        print("Add symbol names now ...")
        cursor = self._cursor

        callsCnt = {}
        sqlcmd = "select name, sum(count) as calls from calls_count_details group by (name)"
        cursor.execute(sqlcmd)
        for row in cursor:
            callsCnt[row[0]] = int(row[1])
        print("Got callsCnt dict ...")

        # Update calls count
        symbolNames = []
        sqlcmd = 'select name, demangle, count(name) as count from symbols group by (name)'
        cursor.execute(sqlcmd)
        names = [description[0] for description in cursor.description]
        for row in cursor:
            symbolName = dict(zip(names, row))
            symbolName["calls"] = callsCnt.get(row[0], 0)
            symbolNames.append(symbolName)
        print("Updated counts for all symbols ...")

        # Insert in batches so a very large symbol list is not one giant script.
        batchSize = 1000
        for start in range(0, len(symbolNames), batchSize):
            batch = symbolNames[start:start + batchSize]
            self._execute_statements([builder.addObjectToDb(s, False) for s in batch])

        print("Updated counts for all symbols ...")

    def _add_all_dependencies(self):
        """Fill the dependencies table and (re)create the depends view."""
        builder = DepedenceiesTableBuilder(self._cursor)
        builder.createTable()

        # Add dependency informations
        elfFiles = self.get_all()
        total = len(elfFiles)
        statements = []
        # Progress is 1-based like the other steps (it used to start at 0).
        for idx, elf in enumerate(elfFiles, 1):
            print("[3/5][%d|%d] add dependencies for file %s" % (idx, total, elf["path"]))
            for dep in elf["deps"]:
                statements.append(dep.addToDb(builder, False))
        self._execute_statements(statements)

        self._cursor.execute("drop view if exists depends")
        self._cursor.execute("drop view if exists calledby")

        dependenciesCols = ["dependencies.%s as %s " % (c, c) for c in builder.getColumns()]
        self._cursor.execute("create view [depends] as select %s, callers.name as caller, callers.componentName as callerComponent, calledbys.name as callee, calledbys.componentName as calleeComponent from dependencies left outer join modules callers on callers.id = dependencies.caller_id left outer join modules calledbys on calledbys.id = dependencies.callee_id" % ", ".join(dependenciesCols))

    def _add_all_indirects(self):
        """Fill the indirects table (ids numbered from 1) and (re)create the indirect_details view."""
        builder = IndirectsTableBuilder(self._cursor)
        builder.createTable()

        # Add indirects information
        elfFiles = self.get_all()
        total = len(elfFiles)
        indirectIdx = 1
        statements = []
        # Progress is 1-based like the other steps (it used to start at 0).
        for idx, elf in enumerate(elfFiles, 1):
            print("[4/5][%d|%d] add indirect dependencies for file %s" % (idx, total, elf["path"]))
            for mod in elf["deps_indirect"]:
                indirect = IndirectDep(indirectIdx, elf, mod)
                statements.append(builder.addObjectToDb(indirect, False))
                indirectIdx += 1
        self._execute_statements(statements)

        self._cursor.execute("drop view if exists indirect_details")
        indirectsCols = ["indirects.%s as %s " % (c, c) for c in builder.getColumns()]
        self._cursor.execute("create view [indirect_details] as select indirects.id as id, %s, callers.name as caller, callers.componentName as callerComponent, calledbys.name as callee, calledbys.componentName as calleeComponent from indirects left outer join modules callers on callers.id = indirects.caller_id left outer join modules calledbys on calledbys.id = indirects.callee_id" % ", ".join(indirectsCols))

    def _add_mem_info(self):
        """Import process memory data from <product out>/mem, if present, and (re)create the modules_mem view."""
        mem_dir = os.path.join(self.get_product_out_path(), "mem")
        if not os.path.exists(mem_dir):
            print("%s does not exists, no memory information added" % mem_dir)
            return

        print("[Optional] add process memory information")
        MemDbBuilder(self._cursor, mem_dir)

        self._cursor.execute("drop view if exists modules_mem")

        memCols = ["objects.%s as %s " % (c, c) for c in ModuleTableBuilder.MEMINFO_COLS]
        self._cursor.execute("create view [modules_mem] as select modules.id as id, modules.name as name, objects.id as object_id, %s from modules LEFT OUTER JOIN objects on modules.path=objects.name" % ", ".join(memCols))

    def _update_meminfo_only(self):
        """Entry point for --mem: import memory information only."""
        print("Update process memory information")
        self._add_mem_info()

    def _add_all_modules(self):
        """Fill the modules table with every file and its computed counters."""
        self._modules_tbl_builder = ModuleTableBuilder(self._cursor)
        self._modules_tbl_builder.createTable()

        # Memory information is available only when the processes table exists.
        # Single-quoted literals: SQLite resolves "..." as identifiers first and
        # rejects them outright when double-quoted strings are disabled.
        self._cursor.execute("select tbl_name from sqlite_master where type='table' and tbl_name='processes'")
        hasMemInfo = self._cursor.fetchone() is not None

        elfFiles = self.get_all()
        total = len(elfFiles)
        statements = []
        for idx, elf in enumerate(elfFiles, 1):
            print("[5/5][%d|%d] add file %s" % (idx, total, elf["path"]))
            statements.append(elf.addToDb(self._modules_tbl_builder, hasMemInfo, False))
        self._execute_statements(statements)

        self._cursor.execute("drop view if exists binaries")
        self._cursor.execute("drop view if exists libraries")

    def __close_db_file(self):
        """Commit pending work and close the cursor and the connection."""
        self._conn.commit()
        self._cursor.close()
        self._conn.close()

    def _add_all(self):
        """Run every collection step in dependency order."""
        self._add_all_symbols()
        self._add_all_calls()
        self._add_all_dependencies()
        self._add_all_indirects()
        self._add_symbol_names()
        self._add_mem_info()
        self._add_all_modules()

def createArgParser():
    """Return the argparse parser for the archinfo collector command line.

    --input may be given several times; its values are collected into a list.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Collect architecture information from asset files.')
    parser.add_argument('-i', '--input',
                        help='input asset files root directory', action='append', required=True)

    parser.add_argument('-o', '--output',
                        help='output architecture information database directory', required=False)

    # The help text used to be a copy of --mem's ("Update memory information only").
    parser.add_argument('-s', '--scan',
                        help='Scan the input directory for product asset directories and collect each of them',
                        required=False, default=False, action='store_true')

    parser.add_argument('-x', '--extract',
                        help='need extract image file system', required=False, default=False, action='store_true')

    parser.add_argument('-e', '--none',
                        help='Update nothing, scan elf files only', required=False, default=False, action='store_true')

    parser.add_argument('-m', '--mem',
                        help='Update memory information only', required=False, default=False, action='store_true')

    parser.add_argument('-l', '--last',
                        help='Update last step module summary information only', required=False, default=False, action='store_true')

    parser.add_argument('-d', '--depends',
                        help='Update dependencies, indirect dependencies, symbol names and module summary information only',
                        required=False, default=False, action='store_true')

    parser.add_argument('-p', '--ip',
                        help='Uploading IP address', required=False)

    parser.add_argument('-n', '--name',
                        help='Build product name', required=False)

    parser.add_argument('-P', '--parameter',
                        help='input parameter file from board by exec param get ', required=False)
    parser.add_argument('-b', '--bootevent',
                        help='input bootevent file from board ', required=False)
    return parser

def scan_assets_dir(args):
    """Run ArchInfoCollector once for every product directory under args.input.

    A product directory sits exactly two levels below args.input and contains
    at least one subdirectory whose name ends with "packages". args.input and
    args.output are overwritten in place for each product collected.
    """
    topDir = os.path.realpath(args.input)
    for root, subdirs, _files in os.walk(topDir):
        # Use os.sep: os.walk yields native separators, not always "/".
        depth = len(root[len(topDir) + 1:].split(os.sep))
        if depth < 2:
            continue
        # Deeper directories can never match, so stop descending below here.
        matched = any(name.endswith("packages") for name in subdirs)
        subdirs[:] = []
        # Collect each product once, even with several "*packages" subdirectories
        # (previously it was collected again for every match).
        if matched:
            args.input = root
            args.output = None
            ArchInfoCollector(args)

def upload(args):
    """Copy <output>/archinfo.db to a timestamped staging directory on the ohos server.

    Runs ssh-copy-id through sshpass, using ohos-serverinfo.txt next to this
    script as the password file, then creates the remote directory over ssh
    and copies the database with scp. args.output is set to args.input in
    place when it is empty.
    """
    from datetime import datetime
    import shlex
    from utils import command

    dst_addr = "ohos@" + args.ip

    # normpath drops a trailing separator so basename() is not "" for "dir/".
    product_name = args.name or os.path.basename(os.path.normpath(args.input))

    # get current directory
    cur_file_dir = os.path.dirname(os.path.realpath(__file__))

    # automate ssh-copy-id
    command("sshpass", "-f", os.path.join(cur_file_dir, "ohos-serverinfo.txt"), "ssh-copy-id", dst_addr)

    # build destination directory
    now = datetime.now()
    dst_dir = os.path.join("/data/ohos/files/staging", product_name, now.strftime("%Y%m%d_%H%M"))

    # The remote side runs this through a shell: quote the path properly so a
    # product name containing quotes or spaces cannot break the command.
    command("ssh", dst_addr, "mkdir -p %s" % shlex.quote(dst_dir))

    # upload
    if not args.output:
        args.output = args.input
    db_file = os.path.join(args.output, "archinfo.db")
    print("scp %s %s:%s" % (db_file, dst_addr, dst_dir))
    command("scp", db_file, "%s:%s" % (dst_addr, dst_dir))

    print("archinfo.db uploaded successfully.")

def do_collect(args):
    """Build the archinfo database(s), collect startup config and parameters, optionally upload.

    With --scan, every product directory under args.input gets its own
    database; otherwise args.input is collected directly. The upload runs
    only when --ip was given.
    """
    collect = scan_assets_dir if args.scan else ArchInfoCollector
    collect(args)

    # Startup configuration and system parameters
    startup_files = config_parser.Collect(args)
    parameter_parser.Collect(args, startup_files)

    if not args.ip:
        return
    upload(args)

from img_extractor import SystemImageProcessor

if __name__ == '__main__':
    parser = createArgParser()
    args = parser.parse_args()

    pro = SystemImageProcessor(args)
    pro.pre_process()

    # Report failures through the exit status instead of swallowing them (the
    # old bare `except:` also trapped Ctrl-C and always exited 0). The image
    # processor is cleaned up in every case.
    exit_code = 0
    try:
        do_collect(args)
    except Exception:
        import traceback
        traceback.print_exc()
        exit_code = 1
    finally:
        pro.cleanup()

    sys.exit(exit_code)
